{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to Urban Driver\n",
    "\n",
    "In this notebook you are going to train the planner introduced in [Urban Driver: Learning to Drive from Real-world Demonstrations Using Policy Gradients](https://openreview.net/pdf?id=ibktAcINCaj).\n",
    "\n",
    "You will train your model using the Lyft Prediction Dataset and [L5Kit](https://github.com/lyft/l5kit).\n",
    "**Before starting, please download the [Lyft L5 Prediction Dataset 2020](https://self-driving.lyft.com/level5/prediction/) and follow [the instructions](https://github.com/lyft/l5kit#download-the-datasets) to correctly organise it.**\n",
    "\n",
    "### Model\n",
    "\n",
    "From the paper:\n",
    "```\n",
    "We use a graph neural network for parametrizing our policy.\n",
    "It combines a PointNet-like architecture for local inputs processing followed by an attention mechanism for global reasoning. In contrast to VectorNet, we use points instead of vectors. Given the set of points corresponding to each input element, we employ 3 PointNet layers to calculate a 128-dimensional feature descriptor. Subsequently, a single layer of scaled dot-product attention performs global feature aggregation, yielding the predicted trajectory. [...] In total, our model contains around 3.5 million trainable parameters, and training takes 30h on 32 Tesla V100 GPUs. For more details we refer to Appendix C.\n",
    "```\n",
    "We also report a diagram of the full model:\n",
    "\n",
    "![model](../../docs/images/urban_driver/model.svg)\n",
    "\n",
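    "The two stages described in the quote (local PointNet-like processing, then global attention) can be illustrated with a toy PyTorch module. This is a minimal sketch, not L5Kit's `VectorizedModel`. The class names `PointNetLayer` and `ToyPolicy` are hypothetical, and so are the layer sizes, the choice of the first element (the SDV) as the attention query, and the absence of padding masks. All of them are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class PointNetLayer(nn.Module):\n",
    "    # Hypothetical PointNet-like layer: a shared per-point MLP, then the max over the\n",
    "    # points of each element is concatenated back onto every point.\n",
    "    def __init__(self, in_dim: int, out_dim: int):\n",
    "        super().__init__()\n",
    "        self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim // 2), nn.ReLU())\n",
    "\n",
    "    def forward(self, x):  # x: (batch, elements, points, in_dim)\n",
    "        h = self.mlp(x)\n",
    "        pooled = h.max(dim=2, keepdim=True).values.expand_as(h)\n",
    "        return torch.cat([h, pooled], dim=-1)  # (batch, elements, points, out_dim)\n",
    "\n",
    "\n",
    "class ToyPolicy(nn.Module):\n",
    "    # Hypothetical policy: 3 local PointNet layers, then one scaled dot-product attention.\n",
    "    def __init__(self, feat_dim: int = 3, hidden: int = 128, num_targets: int = 36):\n",
    "        super().__init__()\n",
    "        self.local = nn.ModuleList([PointNetLayer(feat_dim, hidden),\n",
    "                                    PointNetLayer(hidden, hidden),\n",
    "                                    PointNetLayer(hidden, hidden)])\n",
    "        self.q = nn.Linear(hidden, hidden)\n",
    "        self.k = nn.Linear(hidden, hidden)\n",
    "        self.v = nn.Linear(hidden, hidden)\n",
    "        self.head = nn.Linear(hidden, num_targets)\n",
    "\n",
    "    def forward(self, points):  # points: (batch, elements, points, feat_dim), element 0 = SDV\n",
    "        h = points\n",
    "        for layer in self.local:\n",
    "            h = layer(h)\n",
    "        elem = h.max(dim=2).values  # one descriptor per element: (batch, elements, hidden)\n",
    "        q = self.q(elem[:, :1])  # the SDV element queries all elements: (batch, 1, hidden)\n",
    "        k, v = self.k(elem), self.v(elem)\n",
    "        scores = q @ k.transpose(1, 2) / k.shape[-1] ** 0.5  # (batch, 1, elements)\n",
    "        attn = torch.softmax(scores, dim=-1)\n",
    "        return self.head((attn @ v).squeeze(1))  # (batch, num_targets)\n",
    "```\n",
    "\n",
    "For example, `ToyPolicy()(torch.randn(2, 5, 20, 3))` returns a tensor of shape `(2, 36)`. In this toy setup that corresponds to 12 future steps of `(X, Y, yaw)` per sample.\n",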
    "\n",
    "#### Inputs\n",
    "Urban Driver is based on a vectorized representation of the world. From the paper:\n",
    "```\n",
    "We define the state as the whole set of static and dynamic elements the model receive as input. Each element is composed of a variable number of points, which can represent both time (e.g. for agents) and space (e.g. for lanes). The number of features per point depends on the element type. We pad all features to a fixed size F to ensure they can share the first fully connected layer. We include all elements up to the listed maximal number in a circular FOV of radius 35m around the SDV. Note that for performance and simplicity we only execute this query once, and then unroll within this world state.\n",
    "```\n",
    "\n",
    "In more detail:\n",
    "\n",
    "\n",
    "| State element(s) | Elements per state | Points per element | Point features description                                                               |\n",
    "|------------------|--------------------|--------------------|------------------------------------------------------------------------------------------|\n",
    "| SDV              | 1                  | 4                  | SDV's X, Y and yaw pose of the current time step, as well as previous timesteps          |\n",
    "| Agents           | up to 30           | 4                  | other agents' X, Y and yaw poses of the current time step, as well as previous timesteps |\n",
    "| Lanes mid        | up to 30           | 20                 | interpolated X, Y points of the lanes' center lines, with optional traffic light signals |\n",
    "| Lanes left       | up to 30           | 20                 | interpolated X, Y points of the left lane boundaries                                     |\n",
    "| Lanes right      | up to 30           | 20                 | interpolated X, Y points of the right lane boundaries                                    |\n",
    "| Crosswalks       | up to 20           | up to 20           | crosswalks' polygon boundaries - X, Y                                                    |\n",
    "\n",
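    "To make the padding step concrete, here is a minimal NumPy sketch that packs a variable number of elements, each with a variable number of points and features, into one fixed-size array. The function name `pad_elements`, the zero padding, the boolean availability mask and the value `F = 3` are assumptions for illustration. They are not L5Kit's vectorizer API.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "F = 3  # hypothetical fixed feature size shared by all element types\n",
    "\n",
    "\n",
    "def pad_elements(elements, max_elements, max_points, feat_dim=F):\n",
    "    # elements: list of (num_points, num_features) arrays, with num_features <= feat_dim\n",
    "    out = np.zeros((max_elements, max_points, feat_dim), dtype=np.float32)\n",
    "    avail = np.zeros((max_elements, max_points), dtype=bool)  # True where a real point exists\n",
    "    for i, pts in enumerate(elements[:max_elements]):\n",
    "        pts = pts[:max_points]\n",
    "        out[i, : len(pts), : pts.shape[1]] = pts  # zero-pad missing points and features\n",
    "        avail[i, : len(pts)] = True\n",
    "    return out, avail\n",
    "\n",
    "\n",
    "# Two crosswalks with 4 and 7 (X, Y) points, padded to the table's limits of 20 x 20.\n",
    "crosswalks, crosswalks_avail = pad_elements([np.ones((4, 2)), np.ones((7, 2))], 20, 20)\n",
    "```\n",
    "\n",
    "Here `crosswalks` has shape `(20, 20, 3)`, and `crosswalks_avail` marks the 11 real points.\n",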
    "\n",
    "\n",
    "#### Outputs\n",
    "Urban Driver outputs the next positions and orientations of the SDV. Each timestep is a tuple consisting of `(X, Y, yaw)`.\n",
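    "\n",
    "In the model-building cell further down, the constructor receives `num_targets=_num_predicted_params * _num_predicted_frames`, which means three values per predicted timestep. The following minimal sketch turns such a flat output back into per-timestep tuples. It assumes the values are laid out timestep by timestep as `(X, Y, yaw)`, which is an assumption and should be checked against the model code before relying on it.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "future_num_frames = 12  # hypothetical; the notebook reads cfg['model_params']['future_num_frames']\n",
    "num_params = 3  # (X, Y, yaw)\n",
    "\n",
    "flat = torch.randn(4, num_params * future_num_frames)  # stand-in for a batch of model outputs\n",
    "traj = flat.view(-1, future_num_frames, num_params)  # (batch, timesteps, 3)\n",
    "xy, yaw = traj[..., :2], traj[..., 2:]  # positions (batch, T, 2) and headings (batch, T, 1)\n",
    "```\n",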
    "\n",
    "### Training in closed loop\n",
    "One of the main features of Urban Driver is how it is trained; from the paper:\n",
    "```\n",
    "[...] we then train a policy network in closed-loop employing policy gradients.\n",
    "We train our proposed method on 100 hours of expert demonstrations on urban roads and show that it learns complex driving policies that generalize well and can perform a variety of driving maneuvers\n",
    "```\n",
    "\n",
    "When training in closed loop, the model does not predict all timesteps at once. Instead, it predicts one action at a time and uses that action to step the surrounding environment forward before repeating the process. This allows the model to capture how the environment evolves as it makes decisions.\n",
    "\n",
    "Compare Figure 3 from the original paper:\n",
    "\n",
    "![method](../../docs/images/urban_driver/method.png)\n",
    "\n",
    "```\n",
    "One iteration of policy gradient update. Given a real-world expert trajectory we sample a policy state by unrolling the policy for T steps. We then compute optimal policy update by backpropagation through time.\n",
    "```\n",
    "\n",
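    "The unroll described in the caption can be sketched with a toy example in which the state is just the SDV position and the policy predicts a displacement. This is a simplified illustration under stated assumptions, not L5Kit's `VectorizedUnrollModel`. The function `unroll_loss`, the toy environment step `state + action`, and the way `discount` weights each step are all hypothetical. The `detach_unroll` argument only mirrors the idea behind the config flag of the same name. When it is `True`, gradients do not flow from later steps back through earlier actions, as in the ablation without BPTT.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "def unroll_loss(policy, state, expert_actions, detach_unroll=False, discount=1.0):\n",
    "    # policy: maps a state (batch, dim) to an action (batch, dim)\n",
    "    # expert_actions: (batch, T, dim) expert displacements used as imitation targets\n",
    "    total = torch.zeros(())\n",
    "    for t in range(expert_actions.shape[1]):\n",
    "        action = policy(state)  # predict one action at a time\n",
    "        total = total + discount ** t * (action - expert_actions[:, t]).abs().mean()  # L1 loss\n",
    "        next_state = state + action  # toy environment step: apply the predicted displacement\n",
    "        # BPTT keeps the graph across steps; detaching cuts it between steps\n",
    "        state = next_state.detach() if detach_unroll else next_state\n",
    "    return total\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "policy = nn.Linear(3, 3)\n",
    "loss = unroll_loss(policy, torch.zeros(2, 3), torch.randn(2, 5, 3))\n",
    "loss.backward()  # backpropagation through time over the 5 unrolled steps\n",
    "```\n",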
    "\n",
    "Closed-loop training has two major advantages:\n",
    "- reducing the domain shift between training and evaluation;\n",
    "- replacing hand-crafted off-policy perturbations with on-policy perturbations generated by the model itself.\n",
    "\n",
    "Again from the paper:\n",
    "```\n",
    "[...] reports performance when all methods are trained to optimize the imitation loss alone. Behavioral cloning yields a high number of trajectory errors and collisions. This is expected, as this approach is known to suffer from the issue of covariate shift \n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from tempfile import gettempdir\n",
    "\n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.data import ChunkedDataset, LocalDataManager\n",
    "from l5kit.dataset import EgoDatasetVectorized\n",
    "from l5kit.planning.vectorized.closed_loop_model import VectorizedUnrollModel\n",
    "from l5kit.planning.vectorized.open_loop_model import VectorizedModel\n",
    "from l5kit.vectorization.vectorizer_builder import build_vectorizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare data path and load cfg\n",
    "\n",
    "By setting the `L5KIT_DATA_FOLDER` environment variable, we point L5Kit to the folder where the data lies.\n",
    "\n",
    "Then, we load our config file with relative paths and other configurations (data loader settings, model and training parameters...)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Download L5 Sample Dataset and install L5Kit\n",
    "import os\n",
    "RunningInCOLAB = 'google.colab' in str(get_ipython())\n",
    "if RunningInCOLAB:\n",
    "    !wget https://raw.githubusercontent.com/lyft/l5kit/master/examples/setup_notebook_colab.sh -q\n",
    "    !sh ./setup_notebook_colab.sh\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = open(\"./dataset_dir.txt\", \"r\").read().strip()\n",
    "else:\n",
    "    print(\"Not running in Google Colab.\")\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = \"/tmp/l5kit_data\"  # change this to your local dataset folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dm = LocalDataManager(None)\n",
    "# get config\n",
    "cfg = load_config_data(\"./config.yaml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== INIT DATASET\n",
    "train_zarr = ChunkedDataset(dm.require(cfg[\"train_data_loader\"][\"key\"])).open()\n",
    "\n",
    "vectorizer = build_vectorizer(cfg, dm)\n",
    "train_dataset = EgoDatasetVectorized(cfg, train_zarr, vectorizer)\n",
    "print(train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Choosing the model\n",
    "\n",
    "You can use this notebook to train not only Urban Driver, but also several ablations included in the paper. We provide the following options:\n",
    "- Urban Driver: this is the default value and the model presented in the paper;\n",
    "- Urban Driver without BPTT: this is an ablation of Urban Driver where we detach the gradient between steps;\n",
    "- Open Loop Planner: this is a vectorized model trained with simple behavioural cloning.\n",
    "\n",
    "We now detail which config changes are necessary to obtain the baseline models. Note that these changes are also required when loading pre-trained models for evaluation:\n",
    "\n",
    "| Model     | Changes to config |\n",
    "| ----------- | ----------- |\n",
    "| Open Loop Planner  (BC-perturb)   | - history_num_frames_ego: 0 |\n",
    "| Open Loop Planner with Ego History  (BC-perturb) | None  |\n",
    "| Urban Driver without BPTT (MS Prediction)   | - future_num_frames: 32 <br/> - warmup_num_frames: 20|\n",
    "| Urban Driver (Ours)   | - future_num_frames: 32 <br/> - warmup_num_frames: 20 <br/> - detach_unroll: False|"
   ]
  },
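  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you prefer to apply these changes programmatically, one minimal approach is to keep the overrides in a dictionary and merge them into a copy of `cfg`. The variant keys and the helper `apply_config_changes` are hypothetical. The parameter names and values come from the table, and this notebook reads all of them from `cfg['model_params']`.\n",
    "\n",
    "```python\n",
    "import copy\n",
    "\n",
    "# Overrides from the table above; every key lives under cfg['model_params'].\n",
    "CONFIG_CHANGES = {\n",
    "    'open_loop': {'history_num_frames_ego': 0},\n",
    "    'open_loop_with_history': {},\n",
    "    'urban_driver_without_bptt': {'future_num_frames': 32, 'warmup_num_frames': 20},\n",
    "    'urban_driver': {'future_num_frames': 32, 'warmup_num_frames': 20, 'detach_unroll': False},\n",
    "}\n",
    "\n",
    "\n",
    "def apply_config_changes(cfg, variant):\n",
    "    # Work on a deep copy so the original cfg stays untouched.\n",
    "    new_cfg = copy.deepcopy(cfg)\n",
    "    new_cfg['model_params'].update(CONFIG_CHANGES[variant])\n",
    "    return new_cfg\n",
    "```\n",
    "\n",
    "For example, `cfg = apply_config_changes(cfg, 'urban_driver')` would run before the model-building cell."
   ]
  },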
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "URBAN_DRIVER = \"Urban Driver\"\n",
    "OPEN_LOOP_PLANNER = \"Open Loop Planner\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = URBAN_DRIVER\n",
    "\n",
    "weights_scaling = [1.0, 1.0, 1.0]\n",
    "\n",
    "_num_predicted_frames = cfg[\"model_params\"][\"future_num_frames\"]\n",
    "_num_predicted_params = len(weights_scaling)\n",
    "\n",
    "\n",
    "if model_name == URBAN_DRIVER:\n",
    "    model = VectorizedUnrollModel(\n",
    "        history_num_frames_ego=cfg[\"model_params\"][\"history_num_frames_ego\"],\n",
    "        history_num_frames_agents=cfg[\"model_params\"][\"history_num_frames_agents\"],\n",
    "        num_targets=_num_predicted_params * _num_predicted_frames,\n",
    "        weights_scaling=weights_scaling,\n",
    "        criterion=nn.L1Loss(reduction=\"none\"),\n",
    "        global_head_dropout=cfg[\"model_params\"][\"global_head_dropout\"],\n",
    "        disable_other_agents=cfg[\"model_params\"][\"disable_other_agents\"],\n",
    "        disable_map=cfg[\"model_params\"][\"disable_map\"],\n",
    "        disable_lane_boundaries=cfg[\"model_params\"][\"disable_lane_boundaries\"],\n",
    "        detach_unroll=cfg[\"model_params\"][\"detach_unroll\"],\n",
    "        warmup_num_frames=cfg[\"model_params\"][\"warmup_num_frames\"],\n",
    "        discount_factor=cfg[\"model_params\"][\"discount_factor\"],\n",
    "    )\n",
    "\n",
    "elif model_name == OPEN_LOOP_PLANNER:\n",
    "    model = VectorizedModel(\n",
    "        history_num_frames_ego=cfg[\"model_params\"][\"history_num_frames_ego\"],\n",
    "        history_num_frames_agents=cfg[\"model_params\"][\"history_num_frames_agents\"],\n",
    "        num_targets=_num_predicted_params * _num_predicted_frames,\n",
    "        weights_scaling=weights_scaling,\n",
    "        criterion=nn.L1Loss(reduction=\"none\"),\n",
    "        global_head_dropout=cfg[\"model_params\"][\"global_head_dropout\"],\n",
    "        disable_other_agents=cfg[\"model_params\"][\"disable_other_agents\"],\n",
    "        disable_map=cfg[\"model_params\"][\"disable_map\"],\n",
    "        disable_lane_boundaries=cfg[\"model_params\"][\"disable_lane_boundaries\"],\n",
    "    )\n",
    "else:\n",
    "    raise ValueError(f\"{model_name=} is invalid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare for training\n",
    "Our `EgoDatasetVectorized` inherits from PyTorch `Dataset`, so we can wrap it in a `DataLoader` to load batches with multiple worker processes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cfg = cfg[\"train_data_loader\"]\n",
    "train_dataloader = DataLoader(train_dataset, shuffle=train_cfg[\"shuffle\"], batch_size=train_cfg[\"batch_size\"],\n",
    "                              num_workers=train_cfg[\"num_workers\"])\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n"
   ]
  },
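  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop below moves every value of a batch to the device with `v.to(device)`, so each batch must be a dictionary of tensors. PyTorch's default collate function produces exactly this when every dataset item is a dictionary of NumPy arrays, as the following minimal sketch shows. `ToyDictDataset` and its keys are hypothetical stand-ins, not the actual output of `EgoDatasetVectorized`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "\n",
    "class ToyDictDataset(Dataset):\n",
    "    # Hypothetical stand-in: each item is a dict of NumPy arrays.\n",
    "    def __len__(self):\n",
    "        return 10\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return {'target_positions': np.full((12, 2), idx, dtype=np.float32),\n",
    "                'target_yaws': np.zeros((12, 1), dtype=np.float32)}\n",
    "\n",
    "\n",
    "loader = DataLoader(ToyDictDataset(), batch_size=4, shuffle=False, num_workers=0)\n",
    "batch = next(iter(loader))  # dict of tensors with a leading batch dimension of 4\n",
    "```"
   ]
  },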
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training loop\n",
    "Here, we purposely include a bare-bones training loop. Many more components can be added to enrich logging and improve performance. Still, the sheer size of our dataset ensures that reasonable performance can be obtained even with this simple loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_it = iter(train_dataloader)\n",
    "progress_bar = tqdm(range(cfg[\"train_params\"][\"max_num_steps\"]))\n",
    "losses_train = []\n",
    "model.train()\n",
    "torch.set_grad_enabled(True)\n",
    "\n",
    "for _ in progress_bar:\n",
    "    try:\n",
    "        data = next(tr_it)\n",
    "    except StopIteration:\n",
    "        tr_it = iter(train_dataloader)\n",
    "        data = next(tr_it)\n",
    "    # Forward pass\n",
    "    data = {k: v.to(device) for k, v in data.items()}\n",
    "    result = model(data)\n",
    "    loss = result[\"loss\"]\n",
    "    # Backward pass\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    losses_train.append(loss.item())\n",
    "    progress_bar.set_description(f\"loss: {loss.item():.4f} loss(avg): {np.mean(losses_train):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot the train loss curve\n",
    "We can plot the train loss against the iterations (batch-wise) to check if our model has converged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.arange(len(losses_train)), losses_train, label=\"train loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
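  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Batch-wise losses are noisy, so trends can be easier to judge on a smoothed curve. Below is a minimal sketch using a simple moving average. The helper name `moving_average` and the default window of 100 batches are arbitrary choices, not part of L5Kit.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def moving_average(values, window=100):\n",
    "    # Mean over a sliding window of `window` losses; returns len(values) - window + 1 points.\n",
    "    values = np.asarray(values, dtype=np.float64)\n",
    "    if len(values) < window:\n",
    "        return values  # not enough points yet: return the raw curve\n",
    "    return np.convolve(values, np.ones(window) / window, mode='valid')\n",
    "```\n",
    "\n",
    "You could then call `plt.plot(moving_average(losses_train), label='train loss (smoothed)')` in the plotting cell above."
   ]
  },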
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Store the model\n",
    "\n",
    "Let's store the model as a TorchScript module. This format allows us to re-load the model and its weights later without requiring the class definition.\n",
    "\n",
    "**Take note of the path, you will use it later to evaluate your planning model!**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "to_save = torch.jit.script(model.cpu())\n",
    "path_to_save = f\"{gettempdir()}/urban_driver.pt\"\n",
    "to_save.save(path_to_save)\n",
    "print(f\"MODEL STORED at {path_to_save}\")"
   ]
  },
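  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check that a TorchScript file can be re-loaded without the class definition, here is a self-contained sketch on a small stand-in module. The module and the file name `toy_scripted.pt` are illustrative only. For the real model, you would pass `path_to_save` to `torch.jit.load` instead.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# A stand-in module; in the notebook you would load the file saved above instead.\n",
    "toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()\n",
    "scripted = torch.jit.script(toy)\n",
    "path = os.path.join(tempfile.gettempdir(), 'toy_scripted.pt')\n",
    "scripted.save(path)\n",
    "\n",
    "# torch.jit.load rebuilds the module from the file alone; no Python class is needed.\n",
    "loaded = torch.jit.load(path, map_location='cpu')\n",
    "x = torch.randn(3, 4)\n",
    "with torch.no_grad():\n",
    "    same = torch.allclose(loaded(x), toy(x))\n",
    "```"
   ]
  },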
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations on training your very own Urban Driver model!\n",
    "### What's Next\n",
    "\n",
    "Now that your model is trained and safely stored, you can evaluate how it performs in our simulation:\n",
    "\n",
    "\n",
    "### [Closed-loop evaluation](./closed_loop_test.ipynb)\n",
    "In this setting, the model **is in full control of the AV's future movements**.\n",
    "\n",
    "## Pre-trained models\n",
    "We provide a collection of pre-trained models, including both our proposed method and several ablations from our paper:\n",
    "- [Urban Driver](https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/models/urban_driver/BPTT.pt);\n",
    "- [Urban Driver without BPTT](https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/models/urban_driver/MS.pt);\n",
    "- [Open Loop](https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/models/urban_driver/OL.pt);\n",
    "- [Open Loop with history](https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/models/urban_driver/OL_HS.pt).\n",
    "\n",
    "To use one of these models, download the corresponding `.pt` file and load it in the evaluation notebooks. Remember to also apply the config changes listed above for that model."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
